In [17]:
# ============================================================
# OASIS-2 – Baseline Dementia Classification (Combined Features)
#
# Features used (NO CDR as predictor):
#   ['age', 'mmse', 'educ', 'ses', 'nwbv', 'etiv', 'asf', 'sex_enc']
#
# Label:
#   dementia_label = 1 if Demented or Converted, 0 if Nondemented
#
# Models:
#   - Logistic Regression
#   - Random Forest
#
# Outputs:
#   - Descriptive statistics
#   - ROC / PR curves (RF, test split)
#   - 5-fold CV metrics for RF (main results)
#   - SHAP summaries for RF (combined model)
#   - Summary text file: oasis_dementia_combined_summary.txt
# ============================================================

import os, warnings
warnings.filterwarnings("ignore")

# Install shap if needed
try:
    import shap  # noqa
except ImportError:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "shap"])
    import shap  # noqa

# Fix deprecated numpy aliases for shap/sklearn internals
import numpy as np
if not hasattr(np, "bool"):
    np.bool = bool
if not hasattr(np, "float"):
    np.float = float

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from IPython.display import display

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    roc_curve,
    precision_recall_curve,
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
)
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

sns.set(style="whitegrid", context="talk")
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)

# ------------------------------------------------------------
# 1. Load dataset
# ------------------------------------------------------------
file_name = "oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx"

if not os.path.exists(file_name):
    raise FileNotFoundError(
        f"Could not find file: {file_name}. "
        "Upload it in the Colab Files pane (left sidebar)."
    )

df = pd.read_excel(file_name)

print("Raw shape:", df.shape)
print("Columns:", df.columns.tolist())
display(df.head())

# ------------------------------------------------------------
# 2. Basic cleaning and type handling
# ------------------------------------------------------------
# Standardize column names
df.columns = [c.strip().lower().replace(" ", "_").replace("/", "_") for c in df.columns]

# Rename to consistent names
df = df.rename(
    columns={
        "subject_id": "subject_id",
        "mri_id": "mri_id",
        "group": "group",
        "visit": "visit",
        "mr_delay": "mr_delay_days",
        "m_f": "sex",
        "hand": "hand",
        "age": "age",
        "educ": "educ",
        "ses": "ses",
        "mmse": "mmse",
        "cdr": "cdr",
        "etiv": "etiv",
        "nwbv": "nwbv",
        "asf": "asf",
    }
)

expected_cols = [
    "subject_id", "mri_id", "group", "visit", "mr_delay_days",
    "sex", "hand", "age", "educ", "ses", "mmse", "cdr", "etiv", "nwbv", "asf"
]
missing = [c for c in expected_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing expected columns: {missing}")

df["subject_id"] = df["subject_id"].astype(str)
for c in ["visit", "mr_delay_days", "age", "educ", "ses", "mmse", "cdr", "etiv", "nwbv", "asf"]:
    df[c] = pd.to_numeric(df[c], errors="coerce")

df["mr_delay_years"] = df["mr_delay_days"] / 365.25

# ------------------------------------------------------------
# 3. Missingness inspection
# ------------------------------------------------------------
print("\nFraction of missing values per column:")
missing_frac = df.isna().mean().sort_values(ascending=False)
display(missing_frac)

plt.figure(figsize=(9, 4))
missing_frac.plot(kind="bar")
plt.ylabel("Fraction missing")
plt.title("Missing data by column")
plt.tight_layout()
plt.show()

# ------------------------------------------------------------
# 4. Baseline-level dataset (MR_DELAY == 0)
#    Label: 1 = Demented or Converted, 0 = Nondemented
# ------------------------------------------------------------
baseline = df[df["mr_delay_days"] == 0].copy()
print("\nBaseline shape:", baseline.shape)
print("Baseline groups:")
display(baseline["group"].value_counts())

label_map = {"Nondemented": 0, "Demented": 1, "Converted": 1}
baseline["dementia_label"] = baseline["group"].map(label_map)
baseline = baseline.dropna(subset=["dementia_label"])
baseline["dementia_label"] = baseline["dementia_label"].astype(int)

print("\nDementia label distribution at baseline (0 = nondemented, 1 = demented/converted):")
display(baseline["dementia_label"].value_counts())

plt.figure(figsize=(4, 4))
sns.countplot(data=baseline, x="dementia_label")
plt.xticks([0, 1], ["Nondemented", "Demented/Converted"])
plt.title("Baseline dementia status")
plt.tight_layout()
plt.show()

# ------------------------------------------------------------
# 5. Longitudinal summary per subject (descriptive only)
# ------------------------------------------------------------
long_summary = (
    df[df["subject_id"].isin(baseline["subject_id"])]
    .groupby("subject_id")
    .agg(
        n_visits=("visit", "max"),
        max_delay_days=("mr_delay_days", "max"),
        max_cdr=("cdr", "max"),
        min_mmse=("mmse", "min"),
    )
    .reset_index()
)
long_summary["followup_years"] = long_summary["max_delay_days"] / 365.25

baseline = baseline.merge(long_summary, on="subject_id", how="left")

print("\nFollow-up years summary (all baseline subjects):")
display(baseline["followup_years"].describe())

plt.figure(figsize=(6, 4))
sns.histplot(baseline["followup_years"], bins=15, kde=True)
plt.xlabel("Follow-up years")
plt.title("Distribution of follow-up duration")
plt.tight_layout()
plt.show()

# ------------------------------------------------------------
# 6. Prepare combined features + label (NO CDR)
# ------------------------------------------------------------
# Encode sex
sex_map = {"F": 0, "M": 1, "Female": 0, "Male": 1}
baseline["sex_enc"] = baseline["sex"].map(sex_map)

label_col = "dementia_label"
y = baseline[label_col].astype(int).values

combined_features = ["age", "mmse", "educ", "ses", "nwbv", "etiv", "asf", "sex_enc"]

print("\nFinal modeling dataset size:", baseline.shape[0])
print("Combined feature set:", combined_features)
print("Nondemented:", int((y == 0).sum()),
      "| Demented/Converted:", int((y == 1).sum()))

# Correlation heatmap
corr_df = baseline[combined_features + [label_col]].copy()
plt.figure(figsize=(10, 8))
corr = corr_df.corr()
sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm", vmin=-1, vmax=1)
plt.title("Correlation matrix – Combined features + dementia status")
plt.tight_layout()
plt.show()

# KDEs for key features
plt.figure(figsize=(14, 4))
for i, col in enumerate(["age", "mmse", "nwbv"], start=1):
    if col not in baseline.columns:
        continue
    plt.subplot(1, 3, i)
    sns.kdeplot(
        data=baseline,
        x=col,
        hue=label_col,
        common_norm=False,
        fill=True,
        alpha=0.4,
    )
    plt.title(col)
    plt.xlabel(col)
plt.suptitle("Baseline distributions by dementia status", y=1.05)
plt.tight_layout()
plt.show()

# ------------------------------------------------------------
# 7. Helper functions for thresholds and metrics
# ------------------------------------------------------------
def compute_metrics_at_threshold(name, y_true, proba, thr):
    """Compute metrics at a given threshold thr."""
    y_pred = (proba >= thr).astype(int)

    auc = roc_auc_score(y_true, proba)
    pr_auc = average_precision_score(y_true, proba)
    acc = accuracy_score(y_true, y_pred)
    sens = recall_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)

    # Always force 2x2 confusion by specifying labels=[0,1]
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    spec = tn / (tn + fp) if (tn + fp) > 0 else np.nan

    print(f"\n=== {name} (threshold = {thr:.3f}) ===")
    print(f"ROC-AUC: {auc:.3f}")
    print(f"PR-AUC:  {pr_auc:.3f}")
    print(f"Accuracy:        {acc:.3f}")
    print(f"Sensitivity:     {sens:.3f}")
    print(f"Specificity:     {spec:.3f}")
    print(f"Precision (PPV): {prec:.3f}")
    print(f"Confusion: TP={tp}, FP={fp}, TN={tn}, FN={fn}")

    return {
        "threshold": thr,
        "auc": auc,
        "pr_auc": pr_auc,
        "acc": acc,
        "sens": sens,
        "spec": spec,
        "prec": prec,
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
    }


def choose_threshold_roc_optimal(y_true, proba):
    """Choose threshold closest to (FPR=0, TPR=1) on the ROC curve."""
    fpr, tpr, thr = roc_curve(y_true, proba)
    # distance^2 to (0,1)
    dist2 = (fpr ** 2) + ((1 - tpr) ** 2)
    idx = int(np.argmin(dist2))
    return thr[idx]


# ------------------------------------------------------------
# 8. Simple train/test split evaluation (for intuition)
# ------------------------------------------------------------
X = baseline[combined_features].values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=RANDOM_STATE
)

print("\nTrain size:", X_train.shape[0], "| Test size:", X_test.shape[0])

imputer = SimpleImputer(strategy="median")
X_train_imp = imputer.fit_transform(X_train)
X_test_imp = imputer.transform(X_test)

# Logistic Regression
log_reg = LogisticRegression(max_iter=1000, random_state=RANDOM_STATE)
log_reg.fit(X_train_imp, y_train)
proba_lr = log_reg.predict_proba(X_test_imp)[:, 1]

# RF
rf_model = RandomForestClassifier(
    n_estimators=300,
    max_depth=None,
    min_samples_split=4,
    min_samples_leaf=2,
    random_state=RANDOM_STATE,
    class_weight="balanced",
)
rf_model.fit(X_train_imp, y_train)
proba_rf = rf_model.predict_proba(X_test_imp)[:, 1]

# Thresholds for each model
for model_name, proba in [("Logistic Regression", proba_lr), ("Random Forest", proba_rf)]:
    # Fixed 0.5
    _ = compute_metrics_at_threshold(
        f"Combined – {model_name} (fixed 0.5)",
        y_test,
        proba,
        thr=0.5,
    )

    # ROC-optimal
    thr_opt = choose_threshold_roc_optimal(y_test, proba)
    _ = compute_metrics_at_threshold(
        f"Combined – {model_name} (ROC-optimal)",
        y_test,
        proba,
        thr=thr_opt,
    )

# ROC & PR curves for RF on test set
plt.figure(figsize=(6, 5))
fpr, tpr, _ = roc_curve(y_test, proba_rf)
auc_rf = roc_auc_score(y_test, proba_rf)
plt.plot(fpr, tpr, label=f"RF (AUC={auc_rf:.2f})")
plt.plot([0, 1], [0, 1], "k--", alpha=0.4)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC curve – Random Forest (combined features, test set)")
plt.legend()
plt.tight_layout()
plt.show()

plt.figure(figsize=(6, 5))
prec, rec, _ = precision_recall_curve(y_test, proba_rf)
ap_rf = average_precision_score(y_test, proba_rf)
plt.plot(rec, prec, label=f"RF (AP={ap_rf:.2f})")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision–Recall curve – Random Forest (combined, test set)")
plt.legend()
plt.tight_layout()
plt.show()

# ------------------------------------------------------------
# 9. 5-fold cross-validation for RF (main performance)
# ------------------------------------------------------------
print("\n\n===== 5-fold cross-validation – Random Forest (combined features) =====")

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)

cv_results = []

for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X, y), start=1):
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]

    imputer_cv = SimpleImputer(strategy="median")
    X_tr_imp = imputer_cv.fit_transform(X_tr)
    X_val_imp = imputer_cv.transform(X_val)

    rf_cv = RandomForestClassifier(
        n_estimators=300,
        max_depth=None,
        min_samples_split=4,
        min_samples_leaf=2,
        random_state=RANDOM_STATE,
        class_weight="balanced",
    )
    rf_cv.fit(X_tr_imp, y_tr)
    proba_val = rf_cv.predict_proba(X_val_imp)[:, 1]

    # AUC / PR-AUC
    auc_val = roc_auc_score(y_val, proba_val)
    pr_auc_val = average_precision_score(y_val, proba_val)

    # Thresholds: 0.5 and ROC-optimal
    for thr_name, thr in [
        ("fixed_0.5", 0.5),
        ("roc_opt", choose_threshold_roc_optimal(y_val, proba_val)),
    ]:
        metrics = compute_metrics_at_threshold(
            f"Fold {fold_idx} – RF (combined, {thr_name})",
            y_val,
            proba_val,
            thr=thr,
        )
        metrics["fold"] = fold_idx
        metrics["thr_name"] = thr_name
        metrics["auc"] = auc_val  # override with CV-specific auc
        metrics["pr_auc"] = pr_auc_val
        cv_results.append(metrics)

cv_df = pd.DataFrame(cv_results)
display(cv_df.head())

print("\nCross-validation summary (Random Forest, combined features):")
summary_cv = (
    cv_df.groupby("thr_name")[["auc", "pr_auc", "acc", "sens", "spec", "prec"]]
    .agg(["mean", "std"])
)
display(summary_cv)

# ------------------------------------------------------------
# 10. SHAP explainability for final RF model (trained on full data)
# ------------------------------------------------------------
print("\nTraining final RF on full combined dataset for SHAP explanations...")
imputer_full = SimpleImputer(strategy="median")
X_full_imp = imputer_full.fit_transform(X)

rf_full = RandomForestClassifier(
    n_estimators=300,
    max_depth=None,
    min_samples_split=4,
    min_samples_leaf=2,
    random_state=RANDOM_STATE,
    class_weight="balanced",
)
rf_full.fit(X_full_imp, y)

try:
    shap.initjs()
    # Background sample
    bg_size = min(100, X_full_imp.shape[0])
    rng = np.random.RandomState(RANDOM_STATE)
    idx_bg = rng.choice(X_full_imp.shape[0], size=bg_size, replace=False)
    background = X_full_imp[idx_bg]

    explainer = shap.TreeExplainer(rf_full, data=background)
    shap_values = explainer.shap_values(X_full_imp)

    if isinstance(shap_values, list):
        shap_values_class1 = shap_values[1]
    else:
        shap_values_class1 = shap_values

    # Beeswarm summary plot
    shap.summary_plot(
        shap_values_class1,
        X_full_imp,
        feature_names=combined_features,
        show=False,
    )
    plt.title("SHAP Beeswarm – RF (combined features)")
    plt.tight_layout()
    plt.show()

    # Bar plot
    shap.summary_plot(
        shap_values_class1,
        X_full_imp,
        feature_names=combined_features,
        plot_type="bar",
        show=False,
    )
    plt.title("Mean |SHAP| Feature Importance – RF (combined features)")
    plt.tight_layout()
    plt.show()

except Exception as e:
    print("SHAP plotting encountered an error (plots skipped):", e)

# ------------------------------------------------------------
# 11. Write summary text file
# ------------------------------------------------------------
summary_lines = []
def add_line(s=""):
    summary_lines.append(s)

N = baseline.shape[0]
n_dem = int(y.sum())

add_line("OASIS-2 Baseline Dementia Classification – Combined Features (No CDR as Predictor)")
add_line("==============================================================================")
add_line(f"Total baseline subjects: {N}")
add_line(f"Demented/Converted:      {n_dem} ({n_dem / N * 100:.1f}%)")
add_line(f"Nondemented:             {N - n_dem} ({(N - n_dem) / N * 100:.1f}%)")
add_line("")
add_line("Combined feature set:")
add_line(", ".join(combined_features))
add_line("")

add_line("Random Forest – 5-fold cross-validation (combined features)")
add_line("Metrics reported as mean ± std over folds.")
add_line("Two thresholds considered: fixed 0.5 and ROC-optimal per fold.")
add_line("")

for thr_name in ["fixed_0.5", "roc_opt"]:
    sub = cv_df[cv_df["thr_name"] == thr_name]
    if sub.empty:
        continue
    auc_mean, auc_std = sub["auc"].mean(), sub["auc"].std()
    pr_mean, pr_std = sub["pr_auc"].mean(), sub["pr_auc"].std()
    acc_mean, acc_std = sub["acc"].mean(), sub["acc"].std()
    sens_mean, sens_std = sub["sens"].mean(), sub["sens"].std()
    spec_mean, spec_std = sub["spec"].mean(), sub["spec"].std()
    prec_mean, prec_std = sub["prec"].mean(), sub["prec"].std()

    add_line(f"Threshold strategy: {thr_name}")
    add_line(f"  AUC:    {auc_mean:.3f} ± {auc_std:.3f}")
    add_line(f"  PR-AUC: {pr_mean:.3f} ± {pr_std:.3f}")
    add_line(f"  Acc:    {acc_mean:.3f} ± {acc_std:.3f}")
    add_line(f"  Sens:   {sens_mean:.3f} ± {sens_std:.3f}")
    add_line(f"  Spec:   {spec_mean:.3f} ± {spec_std:.3f}")
    add_line(f"  PPV:    {prec_mean:.3f} ± {prec_std:.3f}")
    add_line("")

summary_path = "oasis_dementia_combined_summary.txt"
with open(summary_path, "w") as f:
    f.write("\n".join(summary_lines))

print(f"\nSummary written to {summary_path}")
print("\nAnalysis complete.")
Raw shape: (373, 15)
Columns: ['Subject ID', 'MRI ID', 'Group', 'Visit', 'MR Delay', 'M/F', 'Hand', 'Age', 'EDUC', 'SES', 'MMSE', 'CDR', 'eTIV', 'nWBV', 'ASF']
Subject ID MRI ID Group Visit MR Delay M/F Hand Age EDUC SES MMSE CDR eTIV nWBV ASF
0 OAS2_0001 OAS2_0001_MR1 Nondemented 1 0 M R 87 14 2.0 27.0 0.0 1986.550000 0.696106 0.883440
1 OAS2_0001 OAS2_0001_MR2 Nondemented 2 457 M R 88 14 2.0 30.0 0.0 2004.479526 0.681062 0.875539
2 OAS2_0002 OAS2_0002_MR1 Demented 1 0 M R 75 12 NaN 23.0 0.5 1678.290000 0.736336 1.045710
3 OAS2_0002 OAS2_0002_MR2 Demented 2 560 M R 76 12 NaN 28.0 0.5 1737.620000 0.713402 1.010000
4 OAS2_0002 OAS2_0002_MR3 Demented 3 1895 M R 80 12 NaN 22.0 0.5 1697.911134 0.701236 1.033623
Fraction of missing values per column:
0
ses 0.050938
mmse 0.005362
subject_id 0.000000
mri_id 0.000000
mr_delay_days 0.000000
sex 0.000000
group 0.000000
visit 0.000000
age 0.000000
hand 0.000000
educ 0.000000
cdr 0.000000
etiv 0.000000
nwbv 0.000000
asf 0.000000
mr_delay_years 0.000000

No description has been provided for this image
Baseline shape: (150, 16)
Baseline groups:
count
group
Nondemented 72
Demented 64
Converted 14

Dementia label distribution at baseline (0 = nondemented, 1 = demented/converted):
count
dementia_label
1 78
0 72

No description has been provided for this image
Follow-up years summary (all baseline subjects):
followup_years
count 150.000000
mean 2.925521
std 1.477717
min 0.999316
25% 1.774812
50% 2.316222
75% 3.919918
max 7.225188

No description has been provided for this image
Final modeling dataset size: 150
Combined feature set: ['age', 'mmse', 'educ', 'ses', 'nwbv', 'etiv', 'asf', 'sex_enc']
Nondemented: 72 | Demented/Converted: 78
No description has been provided for this image
No description has been provided for this image
Train size: 120 | Test size: 30

=== Combined – Logistic Regression (fixed 0.5) (threshold = 0.500) ===
ROC-AUC: 0.705
PR-AUC:  0.798
Accuracy:        0.600
Sensitivity:     0.500
Specificity:     0.714
Precision (PPV): 0.667
Confusion: TP=8, FP=4, TN=10, FN=8

=== Combined – Logistic Regression (ROC-optimal) (threshold = 0.438) ===
ROC-AUC: 0.705
PR-AUC:  0.798
Accuracy:        0.633
Sensitivity:     0.625
Specificity:     0.643
Precision (PPV): 0.667
Confusion: TP=10, FP=5, TN=9, FN=6

=== Combined – Random Forest (fixed 0.5) (threshold = 0.500) ===
ROC-AUC: 0.714
PR-AUC:  0.812
Accuracy:        0.700
Sensitivity:     0.562
Specificity:     0.857
Precision (PPV): 0.818
Confusion: TP=9, FP=2, TN=12, FN=7

=== Combined – Random Forest (ROC-optimal) (threshold = 0.645) ===
ROC-AUC: 0.714
PR-AUC:  0.812
Accuracy:        0.733
Sensitivity:     0.562
Specificity:     0.929
Precision (PPV): 0.900
Confusion: TP=9, FP=1, TN=13, FN=7
No description has been provided for this image
No description has been provided for this image

===== 5-fold cross-validation – Random Forest (combined features) =====

=== Fold 1 – RF (combined, fixed_0.5) (threshold = 0.500) ===
ROC-AUC: 0.836
PR-AUC:  0.893
Accuracy:        0.833
Sensitivity:     0.800
Specificity:     0.867
Precision (PPV): 0.857
Confusion: TP=12, FP=2, TN=13, FN=3

=== Fold 1 – RF (combined, roc_opt) (threshold = 0.550) ===
ROC-AUC: 0.836
PR-AUC:  0.893
Accuracy:        0.833
Sensitivity:     0.800
Specificity:     0.867
Precision (PPV): 0.857
Confusion: TP=12, FP=2, TN=13, FN=3

=== Fold 2 – RF (combined, fixed_0.5) (threshold = 0.500) ===
ROC-AUC: 0.849
PR-AUC:  0.816
Accuracy:        0.800
Sensitivity:     0.800
Specificity:     0.800
Precision (PPV): 0.800
Confusion: TP=12, FP=3, TN=12, FN=3

=== Fold 2 – RF (combined, roc_opt) (threshold = 0.618) ===
ROC-AUC: 0.849
PR-AUC:  0.816
Accuracy:        0.833
Sensitivity:     0.733
Specificity:     0.933
Precision (PPV): 0.917
Confusion: TP=11, FP=1, TN=14, FN=4

=== Fold 3 – RF (combined, fixed_0.5) (threshold = 0.500) ===
ROC-AUC: 0.808
PR-AUC:  0.860
Accuracy:        0.733
Sensitivity:     0.562
Specificity:     0.929
Precision (PPV): 0.900
Confusion: TP=9, FP=1, TN=13, FN=7

=== Fold 3 – RF (combined, roc_opt) (threshold = 0.449) ===
ROC-AUC: 0.808
PR-AUC:  0.860
Accuracy:        0.733
Sensitivity:     0.688
Specificity:     0.786
Precision (PPV): 0.786
Confusion: TP=11, FP=3, TN=11, FN=5

=== Fold 4 – RF (combined, fixed_0.5) (threshold = 0.500) ===
ROC-AUC: 0.857
PR-AUC:  0.899
Accuracy:        0.700
Sensitivity:     0.688
Specificity:     0.714
Precision (PPV): 0.733
Confusion: TP=11, FP=4, TN=10, FN=5

=== Fold 4 – RF (combined, roc_opt) (threshold = 0.431) ===
ROC-AUC: 0.857
PR-AUC:  0.899
Accuracy:        0.833
Sensitivity:     0.938
Specificity:     0.714
Precision (PPV): 0.789
Confusion: TP=15, FP=4, TN=10, FN=1

=== Fold 5 – RF (combined, fixed_0.5) (threshold = 0.500) ===
ROC-AUC: 0.795
PR-AUC:  0.868
Accuracy:        0.767
Sensitivity:     0.625
Specificity:     0.929
Precision (PPV): 0.909
Confusion: TP=10, FP=1, TN=13, FN=6

=== Fold 5 – RF (combined, roc_opt) (threshold = 0.516) ===
ROC-AUC: 0.795
PR-AUC:  0.868
Accuracy:        0.800
Sensitivity:     0.625
Specificity:     1.000
Precision (PPV): 1.000
Confusion: TP=10, FP=0, TN=14, FN=6
threshold auc pr_auc acc sens spec prec tp fp tn fn fold thr_name
0 0.500000 0.835556 0.892860 0.833333 0.800000 0.866667 0.857143 12 2 13 3 1 fixed_0.5
1 0.549905 0.835556 0.892860 0.833333 0.800000 0.866667 0.857143 12 2 13 3 1 roc_opt
2 0.500000 0.848889 0.815895 0.800000 0.800000 0.800000 0.800000 12 3 12 3 2 fixed_0.5
3 0.618354 0.848889 0.815895 0.833333 0.733333 0.933333 0.916667 11 1 14 4 2 roc_opt
4 0.500000 0.808036 0.860377 0.733333 0.562500 0.928571 0.900000 9 1 13 7 3 fixed_0.5
Cross-validation summary (Random Forest, combined features):
auc pr_auc acc sens spec prec
mean std mean std mean std mean std mean std mean std
thr_name
fixed_0.5 0.828853 0.026694 0.867316 0.033049 0.766667 0.052705 0.695000 0.105549 0.847619 0.091535 0.839913 0.073561
roc_opt 0.828853 0.026694 0.867316 0.033049 0.806667 0.043461 0.756667 0.119628 0.860000 0.113769 0.869799 0.090597
Training final RF on full combined dataset for SHAP explanations...
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Summary written to oasis_dementia_combined_summary.txt

Analysis complete.
In [22]:
# ===========================
# MINIMAL PREP CELL FOR VISUALIZATIONS (EXCEL VERSION)
# ===========================

import pandas as pd
import numpy as np
import os

# ---- EDIT YOUR FILE NAME EXACTLY AS UPLOADED ----
csv_path = "oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx"

# Read Excel, not CSV
raw = pd.read_excel(csv_path)

# Clean column names
raw.columns = [c.strip().lower().replace(" ", "_") for c in raw.columns]

# Fix sex column name
if "m_f" not in raw.columns and "m/f" in raw.columns:
    raw["m_f"] = raw["m/f"]

# Remove missing rows for required ID columns
raw = raw.dropna(subset=["subject_id", "group"])

# Baseline = Visit 1
baseline = raw[raw["visit"] == 1].copy()

# Encode sex (M/F)
sex_map = {"M": 1, "F": 0, "m": 1, "f": 0}
baseline["sex_enc"] = baseline["m_f"].map(sex_map)

# Dementia label map
label_map = {"Nondemented": 0, "Demented": 1, "Converted": 1}
baseline["dementia_label"] = baseline["group"].map(label_map)

# Feature set for visualizations
feature_cols = ["age", "mmse", "educ", "ses", "nwbv", "etiv", "asf", "sex_enc"]

# Final modeling dataframe
df_model = baseline[feature_cols + ["dementia_label"]].dropna().copy()

print("df_model created:", df_model.shape)
print("feature_cols:", feature_cols)
print(df_model.head(10))
df_model created: (142, 9)
feature_cols: ['age', 'mmse', 'educ', 'ses', 'nwbv', 'etiv', 'asf', 'sex_enc']
    age  mmse  educ  ses      nwbv     etiv      asf  sex_enc  dementia_label
0    87  27.0    14  2.0  0.696106  1986.55  0.88344        1               0
5    88  28.0    18  3.0  0.709512  1215.33  1.44406        0               0
7    80  28.0    12  4.0  0.711502  1688.58  1.03933        1               0
13   93  30.0    14  2.0  0.697599  1271.51  1.38024        0               0
15   68  27.0    12  2.0  0.806315  1456.60  1.20486        1               1
17   66  30.0    12  3.0  0.768708  1446.66  1.21314        0               1
19   78  29.0    16  2.0  0.747875  1333.37  1.31621        0               0
22   81  30.0    12  4.0  0.715019  1229.72  1.42716        0               0
25   76  21.0    16  3.0  0.696770  1601.89  1.09558        1               1
27   88  25.0     8  4.0  0.659691  1650.60  1.06325        1               1
In [23]:
# ============================================================
# FULL VISUALIZATION SUITE FOR OASIS-2 BASELINE df_model
# ============================================================

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os

# Create folder for figures
os.makedirs("figures", exist_ok=True)

# Pretty plotting style
sns.set(style="whitegrid", font_scale=1.3)


# ============================================================
# 1) CLASS BALANCE PLOT
# ============================================================
def plot_class_balance(df):
    plt.figure(figsize=(7,5))
    sns.countplot(x="dementia_label", data=df, palette="viridis")
    plt.title("Class Distribution: Dementia vs Non-Dementia")
    plt.xticks([0,1], ["Non-Demented", "Demented/Converted"])
    plt.ylabel("Count")
    plt.xlabel("")
    plt.tight_layout()
    plt.savefig("figures/class_balance.png")
    plt.show()

plot_class_balance(df_model)


# ============================================================
# 2) FEATURE HISTOGRAMS + KDE PER CLASS
# ============================================================
def plot_feature_histograms(df, features):
    for f in features:
        plt.figure(figsize=(8,5))
        sns.histplot(data=df, x=f, hue="dementia_label", kde=True,
                     palette="viridis", element="step")
        plt.title(f"Distribution of {f} by Dementia Status")
        plt.tight_layout()
        plt.savefig(f"figures/hist_{f}.png")
        plt.show()

plot_feature_histograms(df_model, feature_cols)


# ============================================================
# 3) BOXPLOTS PER FEATURE
# ============================================================
def plot_boxplots(df, features):
    for f in features:
        plt.figure(figsize=(7,5))
        sns.boxplot(x="dementia_label", y=f, data=df, palette="viridis")
        plt.title(f"{f} by Dementia Status")
        plt.xticks([0,1], ["Non-Demented", "Demented"])
        plt.tight_layout()
        plt.savefig(f"figures/box_{f}.png")
        plt.show()

plot_boxplots(df_model, feature_cols)


# ============================================================
# 4) CORRELATION HEATMAP
# ============================================================
def plot_corr_heatmap(df):
    corr = df[feature_cols + ["dementia_label"]].corr()
    plt.figure(figsize=(12,10))
    sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f")
    plt.title("Correlation Heatmap (Baseline Features + Target)")
    plt.tight_layout()
    plt.savefig("figures/correlation_heatmap.png")
    plt.show()

plot_corr_heatmap(df_model)


# ============================================================
# 5) PAIRPLOT (SEABORN)
# ============================================================
sns.pairplot(df_model[["age","mmse","nwbv","etiv","dementia_label"]],
             hue="dementia_label", palette="viridis", diag_kind="kde")
plt.savefig("figures/pairplot.png")
plt.show()


# ============================================================
# 6) VIOLIN PLOTS — DISTRIBUTION SHAPE + SPREAD
# ============================================================
def plot_violins(df, features):
    for f in features:
        plt.figure(figsize=(8,5))
        sns.violinplot(x="dementia_label", y=f, data=df, palette="viridis")
        plt.title(f"Violin Plot: {f} by Dementia Status")
        plt.xticks([0,1], ["Non-Demented", "Demented"])
        plt.tight_layout()
        plt.savefig(f"figures/violin_{f}.png")
        plt.show()

plot_violins(df_model, feature_cols)


# ============================================================
# 7) AGE vs BRAIN VOLUME — KEY BIOMARKER INTERACTION
# ============================================================
plt.figure(figsize=(8,6))
sns.scatterplot(data=df_model, x="age", y="nwbv",
                hue="dementia_label", palette="viridis", s=80)
plt.title("Age vs Normalized Whole-Brain Volume (nWBV)")
plt.tight_layout()
plt.savefig("figures/scatter_age_nwbv.png")
plt.show()


# ============================================================
# 8) MMSE vs BRAIN VOLUME — Cognitive vs Structural Decline
# ============================================================
plt.figure(figsize=(8,6))
sns.scatterplot(data=df_model, x="mmse", y="nwbv",
                hue="dementia_label", palette="viridis", s=80)
plt.title("MMSE vs Brain Volume (nWBV)")
plt.tight_layout()
plt.savefig("figures/scatter_mmse_nwbv.png")
plt.show()


# ============================================================
# 9) EDUCATION vs SES (Sociodemographic patterns)
# ============================================================
plt.figure(figsize=(8,6))
sns.scatterplot(data=df_model, x="educ", y="ses",
                hue="dementia_label", palette="viridis", s=80)
plt.title("Education vs SES by Dementia Status")
plt.tight_layout()
plt.savefig("figures/scatter_educ_ses.png")
plt.show()


# ============================================================
# 10) eTIV vs nWBV — True atrophy visualization
# ============================================================
plt.figure(figsize=(8,6))
sns.scatterplot(data=df_model, x="etiv", y="nwbv",
                hue="dementia_label", palette="viridis", s=80)
plt.title("eTIV vs nWBV (Atrophy Indicator)")
plt.tight_layout()
plt.savefig("figures/scatter_etiv_nwbv.png")
plt.show()


print("🎉 All visualizations generated and saved in /figures/")
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
🎉 All visualizations generated and saved in /figures/
In [24]:
# ============================================================
# BEAUTIFUL MULTI-COLOR VISUALIZATION SUITE (IEEE-READY)
# ============================================================

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

os.makedirs("figures", exist_ok=True)

# Different palettes for variety
palette1 = "viridis"
palette2 = "crest"
palette3 = "rocket"
palette4 = "flare"
palette5 = "magma"
palette6 = "coolwarm"
palette7 = "Spectral"
palette8 = "cubehelix"
palette9 = "icefire"
palette10 = "Set2"

sns.set_context("talk", font_scale=1.15)


# ============================================================
# 1) CLASS BALANCE
# ============================================================
sns.set_style("whitegrid")
plt.figure(figsize=(7,5))
sns.countplot(x="dementia_label", data=df_model, palette=palette7)
plt.title("Class Distribution", fontsize=18)
plt.xticks([0,1], ["Non-Demented", "Demented"], fontsize=14)
plt.ylabel("Count")
plt.tight_layout()
plt.savefig("figures/01_class_balance.png")
plt.show()


# ============================================================
# 2) HISTOGRAMS (Different palette each feature)
# ============================================================
palettes_cycle = [palette1, palette2, palette3, palette4, palette5,
                  palette6, palette7, palette8]

sns.set_style("ticks")

for i, f in enumerate(feature_cols):
    plt.figure(figsize=(8,5))
    sns.histplot(
        df_model, x=f, hue="dementia_label",
        kde=True, element="step", alpha=0.5,
        palette=palettes_cycle[i % len(palettes_cycle)]
    )
    plt.title(f"Distribution of {f}", fontsize=18)
    plt.xlabel(f)
    plt.tight_layout()
    plt.savefig(f"figures/02_hist_{f}.png")
    plt.show()


# ============================================================
# 3) BOXPLOTS
# ============================================================
sns.set_style("whitegrid")
for i, f in enumerate(feature_cols):
    plt.figure(figsize=(7,5))
    sns.boxplot(
        x="dementia_label", y=f, data=df_model,
        palette=palette10
    )
    plt.title(f"{f} Comparison", fontsize=18)
    plt.xticks([0,1], ["Non-Demented", "Demented"], fontsize=14)
    plt.tight_layout()
    plt.savefig(f"figures/03_box_{f}.png")
    plt.show()


# ============================================================
# 4) VIOLIN PLOTS (High quality)
# ============================================================
sns.set_style("darkgrid")

for f in feature_cols:
    plt.figure(figsize=(8,5))
    sns.violinplot(
        data=df_model,
        x="dementia_label", y=f,
        palette=palette3, inner="quartile"
    )
    plt.title(f"Violin Plot: {f}", fontsize=18)
    plt.xticks([0,1], ["Non-Demented", "Demented"], fontsize=14)
    plt.tight_layout()
    plt.savefig(f"figures/04_violin_{f}.png")
    plt.show()


# ============================================================
# 5) CORRELATION HEATMAP (advanced)
# ============================================================
sns.set_style("white")

plt.figure(figsize=(12,10))
corr = df_model.corr()
sns.heatmap(
    corr, annot=True, cmap=palette9,
    linewidths=0.5, linecolor="white",
    square=True, cbar_kws={"shrink": 0.8}
)
plt.title("Correlation Matrix (Baseline Features)", fontsize=20)
plt.tight_layout()
plt.savefig("figures/05_corr_heatmap.png")
plt.show()


# ============================================================
# 6) SCATTERPLOT: Age vs Brain Volume
# ============================================================
sns.set_style("ticks")

plt.figure(figsize=(8,6))
sns.scatterplot(
    data=df_model,
    x="age", y="nwbv",
    hue="dementia_label",
    palette=palette6, s=120, alpha=0.8, edgecolor="black"
)
plt.title("Age vs nWBV", fontsize=18)
plt.tight_layout()
plt.savefig("figures/06_scatter_age_nwbv.png")
plt.show()


# ============================================================
# 7) SCATTERPLOT: MMSE vs nWBV
# ============================================================
plt.figure(figsize=(8,6))
sns.scatterplot(
    data=df_model,
    x="mmse", y="nwbv",
    hue="dementia_label",
    palette=palette8, s=120, alpha=0.85, edgecolor="black"
)
plt.title("MMSE vs Brain Volume", fontsize=18)
plt.tight_layout()
plt.savefig("figures/07_scatter_mmse_nwbv.png")
plt.show()


# ============================================================
# 8) EDUC vs SES
# ============================================================
plt.figure(figsize=(8,6))
sns.scatterplot(
    data=df_model, x="educ", y="ses",
    hue="dementia_label", palette=palette4,
    s=120, edgecolor="black", alpha=0.8
)
plt.title("Education vs SES", fontsize=18)
plt.tight_layout()
plt.savefig("figures/08_scatter_educ_ses.png")
plt.show()


# ============================================================
# 9) eTIV vs nWBV (atrophy)
# ============================================================
plt.figure(figsize=(8,6))
sns.scatterplot(
    data=df_model, x="etiv", y="nwbv",
    hue="dementia_label", palette=palette2,
    s=120, edgecolor="black", alpha=0.8
)
plt.title("eTIV vs Brain Volume (Atrophy)", fontsize=18)
plt.tight_layout()
plt.savefig("figures/09_scatter_etiv_nwbv.png")
plt.show()


print("🎉 ALL MULTI-COLOR VISUALIZATIONS GENERATED SUCCESSFULLY!")
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
🎉 ALL MULTI-COLOR VISUALIZATIONS GENERATED SUCCESSFULLY!